# Copyright 2022 The Nerfstudio Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Code to train model.
"""
from __future__ import annotations

import dataclasses
import functools
import os
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Type, Union

import numpy as np
import torch
from rich.console import Console
from torch.cuda.amp.grad_scaler import GradScaler
from typing_extensions import Literal

from nerfstudio.configs.experiment_config import ExperimentConfig
from nerfstudio.engine.callbacks import (
    TrainingCallback,
    TrainingCallbackAttributes,
    TrainingCallbackLocation,
)
from nerfstudio.engine.optimizers import Optimizers
from nerfstudio.pipelines.base_pipeline import VanillaPipeline
from nerfstudio.utils import profiler, writer
from nerfstudio.utils.decorators import (
    check_eval_enabled,
    check_main_thread,
    check_viewer_enabled,
)
from nerfstudio.utils.misc import step_check
from nerfstudio.utils.writer import EventName, TimeWriter
from nerfstudio.viewer.server import viewer_utils
from MSTH.utils import Timer

# Shared rich console used for the status output of this module.
CONSOLE = Console(width=120)

# Return type of the per-step training methods: (total loss, loss dict, metrics dict).
# NOTE: "INTERATION" is a misspelling of "ITERATION"; the alias is referenced by the method
# annotations in this module, so the name is kept as-is.
TRAIN_INTERATION_OUTPUT = Tuple[  # pylint: disable=invalid-name
    torch.Tensor, Dict[str, torch.Tensor], Dict[str, torch.Tensor]
]
# A device given either as a torch.device or as a string such as "cpu" or "cuda:0".
TORCH_DEVICE = Union[torch.device, str]  # pylint: disable=invalid-name

from MSTH.video_pipeline import VideoPipeline, VideoPipelineConfig


@dataclass
class TrainerConfig(ExperimentConfig):
    """Configuration for training regimen"""

    # The lambda defers the lookup of Trainer, which is defined later in this module.
    _target: Type = field(default_factory=lambda: Trainer)
    """target class to instantiate"""
    # Checkpoint cadence; checked by Trainer.train via step_check.
    steps_per_save: int = 1000
    """Number of steps between saves."""
    # The three steps_per_eval_* cadences below are checked by Trainer.eval_iteration via step_check.
    steps_per_eval_batch: int = 500
    """Number of steps between randomly sampled batches of rays."""
    steps_per_eval_image: int = 500
    """Number of steps between single eval images."""
    steps_per_eval_all_images: int = 25000
    """Number of steps between eval all images."""
    # Trainer.train runs this many iterations, counted from its start step.
    max_num_iterations: int = 1000000
    """Maximum number of iterations to run."""
    # Trainer.__init__ forces this off when the selected device is "cpu".
    mixed_precision: bool = False
    """Whether or not to use mixed precision for training."""
    # When True, Trainer.save_checkpoint removes every other entry of the checkpoint directory.
    save_only_latest_checkpoint: bool = True
    """Whether to only save the latest checkpoint or all checkpoints."""
    # optional parameters if we want to resume training
    load_dir: Optional[Path] = None
    """Optionally specify a pre-trained model directory to load from."""
    load_step: Optional[int] = None
    """Optionally specify model step to load from; if none, will find most recent model in load_dir."""
    # NOTE(review): load_config is not read anywhere in the visible part of this module.
    load_config: Optional[Path] = None
    """Path to config YAML file."""
    # When True, Trainer.train_iteration adds per-parameter gradient norms to the metrics dict.
    log_gradients: bool = False
    """Optionally log gradients during training"""


class Trainer:
    """Trainer class

    Args:
        config: The configuration object.
        local_rank: Local rank of the process.
        world_size: World size of the process.

    Attributes:
        config: The configuration object.
        local_rank: Local rank of the process.
        world_size: World size of the process.
        device: The device to run the training on.
        pipeline: The pipeline object.
        optimizers: The optimizers object.
        callbacks: The callbacks object.
    """

    pipeline: VanillaPipeline
    optimizers: Optimizers
    callbacks: List[TrainingCallback]

    def __init__(self, config: TrainerConfig, local_rank: int = 0, world_size: int = 1) -> None:
        """Store the run configuration and derive device, precision and output paths.

        Args:
            config: The configuration object.
            local_rank: Local rank of the process; also used as the CUDA device index.
            world_size: World size of the process; a value of 0 selects the CPU.
        """
        self.config = config
        self.local_rank = local_rank
        self.world_size = world_size
        self.viewer_state = None
        self._start_step: int = 0

        # world_size == 0 is the signal for CPU-only training
        self.device: TORCH_DEVICE = f"cuda:{local_rank}" if world_size != 0 else "cpu"

        use_mixed_precision: bool = self.config.mixed_precision
        if self.device == "cpu":
            use_mixed_precision = False
            CONSOLE.print("Mixed precision is disabled for CPU training.")
        self.mixed_precision: bool = use_mixed_precision

        # gradient scaler, enabled only together with mixed precision
        self.grad_scaler = GradScaler(enabled=self.mixed_precision)

        # output locations: experiment base directory and the checkpoint directory
        self.base_dir: Path = config.get_base_dir()
        self.checkpoint_dir: Path = config.get_checkpoint_dir()
        CONSOLE.log(f"Saving checkpoints to: {self.checkpoint_dir}")

    def setup(self, test_mode: Literal["test", "val", "inference"] = "val") -> None:
        """Build the pipeline, optimizers, callbacks, viewer and loggers, in that order.

        Args:
            test_mode:
                'val': loads train/val datasets into memory
                'test': loads train/test datasets into memory
                'inference': does not load any dataset into memory
        """
        cfg = self.config

        self.pipeline = cfg.pipeline.setup(
            device=self.device,
            test_mode=test_mode,
            world_size=self.world_size,
            local_rank=self.local_rank,
        )
        self.optimizers = self.setup_optimizers()

        # restore pipeline/optimizer/scaler state (when a load_dir is configured) before the
        # callbacks receive references to them
        self._load_checkpoint()

        callback_attributes = TrainingCallbackAttributes(
            optimizers=self.optimizers,  # type: ignore
            grad_scaler=self.grad_scaler,  # type: ignore
            pipeline=self.pipeline,  # type: ignore
        )
        self.callbacks = self.pipeline.get_training_callbacks(callback_attributes)

        # viewer: started only when enabled and only on local rank 0
        viewer_log_path = self.base_dir / cfg.viewer.relative_log_filename
        self.viewer_state = None
        banner_messages = None
        if cfg.is_viewer_enabled() and self.local_rank == 0:
            datapath = self.pipeline.datamanager.get_datapath()
            # fall back to the experiment directory when the datamanager reports no data path
            self.viewer_state, banner_messages = viewer_utils.setup_viewer(
                cfg.viewer,
                log_filename=viewer_log_path,
                datapath=self.base_dir if datapath is None else datapath,
            )
        self._check_viewer_warnings()

        # event writers, local writer and profiler
        writer_log_path = self.base_dir / cfg.logging.relative_log_dir
        writer.setup_event_writer(
            cfg.is_wandb_enabled(),
            cfg.is_tensorboard_enabled(),
            log_dir=writer_log_path,
        )
        writer.setup_local_writer(
            cfg.logging,
            max_iter=cfg.max_num_iterations,
            banner_messages=banner_messages,
        )
        writer.put_config(name="config", config_dict=dataclasses.asdict(cfg), step=0)
        profiler.setup_profiler(cfg.logging)

    def setup_optimizers(self, reset_step=None) -> Optimizers:
        """Create the Optimizers object for the pipeline's parameter groups.

        When the datamanager config carries a camera optimizer whose mode is not "off", its
        optimizer/scheduler pair is registered under the camera optimizer's parameter group name.

        Args:
            reset_step: Accepted but not used by this method.

        Returns:
            The optimizers object given the trainer config.
        """
        # work on a copy: the camera entry added below is not written into self.config.optimizers
        optimizer_config = self.config.optimizers.copy()
        param_groups = self.pipeline.get_param_groups()

        camera_opt = self.config.pipeline.datamanager.camera_optimizer
        camera_opt_active = camera_opt is not None and camera_opt.mode != "off"
        if camera_opt_active:
            group_name = camera_opt.param_group
            # the camera group must not collide with a group that is already configured
            assert group_name not in optimizer_config
            optimizer_config[group_name] = dict(
                optimizer=camera_opt.optimizer,
                scheduler=camera_opt.scheduler,
            )
        return Optimizers(optimizer_config, param_groups)
    
    # def reset_upsample_optimizer(self) -> Optimizers:

    def train(self) -> None:
        """Run the training loop, then save a final checkpoint.

        Each iteration runs the before/after training callbacks around ``train_iteration``,
        updates the viewer, logs on the configured cadence, runs evaluation when an eval
        dataset is present, and saves checkpoints every ``steps_per_save`` steps.
        """
        assert self.pipeline.datamanager.train_dataset is not None, "Missing DatsetInputs"

        transform_path = self.base_dir / "dataparser_transforms.json"
        self.pipeline.datamanager.train_dataparser_outputs.save_dataparser_transform(transform_path)

        self._init_viewer_state()
        with TimeWriter(writer, EventName.TOTAL_TRAIN_TIME):
            total_iterations = self.config.max_num_iterations
            # pre-bind step so the code after the loop has a value even if the loop body never runs
            step = 0
            first_step = self._start_step
            for step in range(first_step, first_step + total_iterations):
                with TimeWriter(writer, EventName.ITER_TRAIN_TIME, step=step) as train_t:
                    self.pipeline.train()

                    # callbacks scheduled before the training iteration
                    for cb in self.callbacks:
                        cb.run_callback_at_location(
                            step, location=TrainingCallbackLocation.BEFORE_TRAIN_ITERATION
                        )

                    loss, loss_dict, metrics_dict = self.train_iteration(step)

                    # callbacks scheduled after the training iteration
                    for cb in self.callbacks:
                        cb.run_callback_at_location(
                            step, location=TrainingCallbackLocation.AFTER_TRAIN_ITERATION
                        )

                # Steps 0 and 1 are not recorded: their timings are skewed and would break the
                # viewer rendering speed estimate.
                if step >= 2:
                    rays_per_batch = self.pipeline.datamanager.get_train_rays_per_batch()
                    writer.put_time(
                        name=EventName.TRAIN_RAYS_PER_SEC,
                        duration=rays_per_batch / train_t.duration,
                        step=step,
                        avg_over_steps=True,
                    )

                self._update_viewer_state(step)

                # log the results of this batch of train rays on the logging cadence
                should_log = step_check(step, self.config.logging.steps_per_log, run_at_zero=True)
                if should_log:
                    writer.put_scalar(name="Train Loss", scalar=loss, step=step)
                    writer.put_dict(name="Train Loss Dict", scalar_dict=loss_dict, step=step)
                    writer.put_dict(name="Train Metrics Dict", scalar_dict=metrics_dict, step=step)

                # evaluation is skipped entirely when there are no validation images
                if self.pipeline.datamanager.eval_dataset:
                    self.eval_iteration(step)

                if step_check(step, self.config.steps_per_save):
                    self.save_checkpoint(step)

                writer.write_out_storage()

        # final checkpoint, then flush remaining events (e.g., total train time)
        self.save_checkpoint(step)
        writer.write_out_storage()

        CONSOLE.rule()
        CONSOLE.print("[bold green]:tada: :tada: :tada: Training Finished :tada: :tada: :tada:", justify="center")
        if self.config.viewer.quit_on_train_completion:
            return
        CONSOLE.print("Use ctrl+c to quit", justify="center")
        self._always_render(step)

    @check_main_thread
    def _always_render(self, step: int) -> None:
        """Keep updating the viewer in an endless loop while no training is running.

        Does nothing when no viewer state exists; otherwise the loop has no exit condition.

        Args:
            step: Step passed to every viewer update.
        """
        if self.viewer_state is None:
            return
        while True:
            # tell the viewer that training is not running, then refresh the scene
            self.viewer_state.vis["renderingState/isTraining"].write(False)
            self._update_viewer_state(step)

    @check_main_thread
    def _check_viewer_warnings(self) -> None:
        """Print a note when the viewer is enabled but neither tensorboard nor wandb is."""
        cfg = self.config
        viewer_only = cfg.is_viewer_enabled() and not (cfg.is_tensorboard_enabled() or cfg.is_wandb_enabled())
        if not viewer_only:
            return
        CONSOLE.print(
            "[NOTE] Not running eval iterations since only viewer is enabled.\n"
            "Use [yellow]--vis {wandb, tensorboard, viewer+wandb, viewer+tensorboard}[/yellow] to run with eval."
        )

    @check_viewer_enabled
    def _init_viewer_state(self) -> None:
        """Initialize the viewer scene from the train dataset.

        When the viewer config does not request training to start, control is handed to
        ``_always_render`` at the start step instead of returning to the caller's training loop.
        """
        assert self.viewer_state and self.pipeline.datamanager.train_dataset
        viewer_cfg = self.config.viewer
        self.viewer_state.init_scene(
            dataset=self.pipeline.datamanager.train_dataset,
            start_train=viewer_cfg.start_train,
        )
        if viewer_cfg.start_train:
            return
        self._always_render(self._start_step)

    @check_viewer_enabled
    def _update_viewer_state(self, step: int) -> None:
        """Render the scene in the viewer with the current model, timing the update.

        A RuntimeError raised by the scene update is not propagated; after a short pause an
        out-of-memory message is written to the viewer instead.

        Args:
            step: current train step
        """
        assert self.viewer_state is not None
        with TimeWriter(writer, EventName.ITER_VIS_TIME, step=step):
            rays_per_batch: int = self.pipeline.datamanager.get_train_rays_per_batch()
            try:
                self.viewer_state.update_scene(self, step, self.pipeline.model, rays_per_batch)
            except RuntimeError:
                # sleep to allow buffer to reset
                time.sleep(0.03)
                assert self.viewer_state.vis is not None
                error_log = self.viewer_state.vis["renderingState/log_errors"]
                error_log.write("Error: GPU out of memory. Reduce resolution to prevent viewer from crashing.")

    @check_viewer_enabled
    def _update_viewer_rays_per_sec(self, train_t: TimeWriter, vis_t: TimeWriter, step: int) -> None:
        """Record the training rays/sec with the visualization time excluded.

        Args:
            train_t: timer object carrying time to execute total training iteration
            vis_t: timer object carrying time to execute visualization step
            step: current step
        """
        rays_per_batch: int = self.pipeline.datamanager.get_train_rays_per_batch()
        # time spent on training alone: total iteration time minus the visualization time
        train_only_duration = train_t.duration - vis_t.duration
        writer.put_time(
            name=EventName.TRAIN_RAYS_PER_SEC,
            duration=rays_per_batch / train_only_duration,
            step=step,
            avg_over_steps=True,
        )

    def _load_checkpoint(self) -> None:
        """Restore pipeline, optimizer and grad-scaler state from ``config.load_dir``.

        Nothing is loaded when ``config.load_dir`` is ``None``. Otherwise the file
        ``step-{load_step:09d}.ckpt`` inside that directory is loaded; when ``config.load_step``
        is ``None`` the highest step among the directory's ``step-<digits>.ckpt`` files is used.
        Training then resumes at the step following the loaded one.

        Raises:
            FileNotFoundError: If no load step is configured and the directory contains no
                checkpoint file to pick from.
            AssertionError: If the selected checkpoint file does not exist.
        """
        load_dir: Optional[Path] = self.config.load_dir
        if load_dir is None:
            CONSOLE.print("No checkpoints to load, training from scratch")
            return

        load_step = self.config.load_step
        if load_step is None:
            CONSOLE.print("Loading latest checkpoint from load_dir")
            # NOTE: this is specific to the checkpoint name format written by save_checkpoint.
            # Only names of that format are parsed, so unrelated files in the directory
            # (hidden files, logs, configs) cannot break the integer conversion.
            prefix = "step-"
            available_steps = [
                int(ckpt.stem[len(prefix) :])
                for ckpt in load_dir.glob(f"{prefix}*.ckpt")
                if ckpt.stem[len(prefix) :].isdecimal()
            ]
            if not available_steps:
                raise FileNotFoundError(f"No checkpoint files matching '{prefix}*.ckpt' found in {load_dir}")
            load_step = max(available_steps)

        load_path: Path = load_dir / f"step-{load_step:09d}.ckpt"
        assert load_path.exists(), f"Checkpoint {load_path} does not exist"
        # load onto the CPU first; the pipeline/optimizers place the state themselves
        loaded_state = torch.load(load_path, map_location="cpu")
        # resume with the step after the one stored in the checkpoint
        self._start_step = loaded_state["step"] + 1
        # load the checkpoints for pipeline, optimizers, and gradient scalar
        self.pipeline.load_pipeline(loaded_state["pipeline"], loaded_state["step"])
        self.optimizers.load_optimizers(loaded_state["optimizers"])
        self.grad_scaler.load_state_dict(loaded_state["scalers"])
        CONSOLE.print(f"done loading checkpoint from {load_path}")

    @check_main_thread
    def save_checkpoint(self, step: int) -> None:
        """Write the pipeline, optimizer and grad-scaler state to ``step-{step:09d}.ckpt``.

        When ``config.save_only_latest_checkpoint`` is set, every other entry in the checkpoint
        directory is deleted afterwards.

        Args:
            step: number of steps in training for given checkpoint
        """
        ckpt_dir = self.checkpoint_dir
        # possibly make the checkpoint directory
        if not ckpt_dir.exists():
            ckpt_dir.mkdir(parents=True, exist_ok=True)

        ckpt_path: Path = ckpt_dir / f"step-{step:09d}.ckpt"
        # a pipeline that exposes `.module` is saved through that inner object
        pipeline = self.pipeline.module if hasattr(self.pipeline, "module") else self.pipeline  # type: ignore
        state = {
            "step": step,
            "pipeline": pipeline.state_dict(),
            "optimizers": {name: opt.state_dict() for name, opt in self.optimizers.optimizers.items()},
            "scalers": self.grad_scaler.state_dict(),
        }
        torch.save(state, ckpt_path)

        if not self.config.save_only_latest_checkpoint:
            return
        # delete everything else in the checkpoint folder
        for entry in ckpt_dir.glob("*"):
            if entry == ckpt_path:
                continue
            entry.unlink()

    @profiler.time_function
    def train_iteration(self, step: int) -> TRAIN_INTERATION_OUTPUT:
        """Run one optimization step with a batch of inputs.

        Args:
            step: Current training step.

        Returns:
            The summed loss, the dictionary of individual losses and the metrics dictionary
            (extended with gradient norms when ``config.log_gradients`` is set).
        """
        self.optimizers.zero_grad_all()

        # autocast wants "cpu"/"cuda", i.e. the device string without its index suffix
        device_type: str = self.device.partition(":")[0]
        with torch.autocast(device_type=device_type, enabled=self.mixed_precision):
            _, loss_dict, metrics_dict = self.pipeline.get_train_loss_dict(step=step)
            loss = functools.reduce(torch.add, loss_dict.values())
        self.grad_scaler.scale(loss).backward()  # type: ignore
        self.optimizers.optimizer_scaler_step_all(self.grad_scaler)

        if self.config.log_gradients:
            total_grad = 0
            for tag, param in self.pipeline.model.named_parameters():
                # "Total" is reserved for the summed entry written after the loop
                assert tag != "Total"
                if param.grad is None:
                    continue
                grad_norm = param.grad.norm()
                metrics_dict[f"Gradients/{tag}"] = grad_norm
                total_grad += grad_norm
            metrics_dict["Gradients/Total"] = total_grad

        self.grad_scaler.update()
        self.optimizers.scheduler_step_all(step)

        return loss, loss_dict, metrics_dict

    @check_eval_enabled
    @profiler.time_function
    def eval_iteration(self, step: int) -> None:
        """Run the batch / single-image / all-images evaluations that are due at this step.

        Each of the three evaluations has its own cadence in the config and is checked
        independently with ``step_check``.

        Args:
            step: Current training step.
        """
        cfg = self.config

        # a batch of eval rays
        if step_check(step, cfg.steps_per_eval_batch):
            _, eval_loss_dict, eval_metrics_dict = self.pipeline.get_eval_loss_dict(step=step)
            eval_loss = functools.reduce(torch.add, eval_loss_dict.values())
            writer.put_scalar(name="Eval Loss", scalar=eval_loss, step=step)
            writer.put_dict(name="Eval Loss Dict", scalar_dict=eval_loss_dict, step=step)
            writer.put_dict(name="Eval Metrics Dict", scalar_dict=eval_metrics_dict, step=step)

        # one eval image, timed so that the test rays/sec can be reported
        if step_check(step, cfg.steps_per_eval_image):
            with TimeWriter(writer, EventName.TEST_RAYS_PER_SEC, write=False) as test_t:
                metrics_dict, images_dict = self.pipeline.get_eval_image_metrics_and_images(step=step)
            test_rays_per_sec = metrics_dict["num_rays"] / test_t.duration
            writer.put_time(
                name=EventName.TEST_RAYS_PER_SEC,
                duration=test_rays_per_sec,
                step=step,
                avg_over_steps=True,
            )
            writer.put_dict(name="Eval Images Metrics", scalar_dict=metrics_dict, step=step)
            for image_name, image in images_dict.items():
                writer.put_image(name="Eval Images/" + image_name, image=image, step=step)

        # all eval images
        if step_check(step, cfg.steps_per_eval_all_images):
            avg_metrics_dict = self.pipeline.get_average_eval_image_metrics(step=step)
            writer.put_dict(name="Eval Images Metrics Dict (all images)", scalar_dict=avg_metrics_dict, step=step)


@dataclass
class VideoTrainerConfig(ExperimentConfig):
    """Configuration for training regimen"""

    # The lambda defers the lookup of VideoTrainer, which is defined later in this module.
    _target: Type = field(default_factory=lambda: VideoTrainer)
    """target class to instantiate"""
    steps_per_save: int = 1000
    """Number of steps between saves."""
    steps_per_eval_batch: int = 500
    """Number of steps between randomly sampled batches of rays."""
    steps_per_eval_image: int = 500
    """Number of steps between single eval images."""
    steps_per_eval_all_images: int = 25000
    """Number of steps between eval all images."""
    max_num_iterations: int = 30000
    """Maximum number of iterations to run."""
    mixed_precision: bool = False
    """Whether or not to use mixed precision for training."""
    save_only_latest_checkpoint: bool = True
    """Whether to only save the latest checkpoint or all checkpoints."""
    # optional parameters if we want to resume training
    load_dir: Optional[Path] = None
    """Optionally specify a pre-trained model directory to load from."""
    load_step: Optional[int] = None
    """Optionally specify model step to load from; if none, will find most recent model in load_dir."""
    load_config: Optional[Path] = None
    """Path to config YAML file."""
    log_gradients: bool = False
    """Optionally log gradients during training"""

    # Built through a factory so every config instance owns its own pipeline config. A single
    # shared instance as the class-level default would be mutated by all configs at once, and
    # dataclasses on Python >= 3.11 rejects unhashable default values outright.
    pipeline: VideoPipelineConfig = field(default_factory=VideoPipelineConfig)
    # Iteration/frame counts; VideoTrainer.setup combines them into the local writer's max_iter.
    num_static_iterations: int = 30000
    num_dynamic_iterations: int = 10000
    num_dynamic_frames: int = 1

    # Cadences checked by VideoTrainer.eval_static_iteration via step_check.
    static_steps_per_eval_batch: int = 500
    static_steps_per_eval_image: int = 1000
    static_steps_per_eval_all_images: int = 30000
    # dynamic_steps_per_eval_batch: int = 500
    # dynamic_steps_per_eval_image: int = 500
    # dynamic_steps_per_eval_all_images: int = 25000
    # Cadences checked by VideoTrainer.eval_all_iteration via step_check.
    all_steps_per_eval_batch: int = 500
    all_steps_per_eval_image: int = 500
    all_steps_per_eval_all_images: int = 30000

    # Whether logging num ray per sec when training
    log_train_t: bool = False
    skip_static: bool = False
    """set to True if you want skip training a static model"""
    static_model_path: Optional[Union[str, Path]] = None
    disable_mlp: bool = True
    skip_dynamic: bool = False

    upsample_enable: bool = False
    upsample_steps: List[int] = field(
        default_factory=lambda: [500, 1000, 1500, 2000]
    )
    resolution_list: List[int] = field(
        default_factory=lambda: [16, 64, 256, 1024]
    )
    # Weight handed to the field encoding's grad_total_variation in
    # VideoTrainer.train_static_iteration; values <= 0 skip that call.
    tv_loss_weight: float = 0.001
    step_scale: int = 20
    reset_proposal_samplers: bool = False
    hash_reinit_std: float = 1e-4

class VideoTrainer(Trainer):
    config: VideoTrainerConfig
    pipeline: VideoPipeline
    optimizers: Optimizers
    callbacks: List[TrainingCallback]

    def __init__(self, config: VideoTrainerConfig, local_rank: int = 0, world_size: int = 1) -> None:
        """Initialize the video trainer.

        All state (config, ranks, device, mixed-precision flag, grad scaler, output directories
        and the empty viewer state) is set up by ``Trainer.__init__``. This override only narrows
        the accepted config type, so the initialization logic lives in one place.

        Args:
            config: The configuration object.
            local_rank: Local rank of the process.
            world_size: World size of the process.
        """
        # VideoTrainerConfig is not a TrainerConfig subclass, hence the type-checker suppression;
        # the base initializer only reads attributes that both config classes provide.
        super().__init__(config, local_rank=local_rank, world_size=world_size)  # type: ignore

    def setup(self, test_mode: Literal["test", "val", "inference"] = "val") -> None:
        """Build the video pipeline, optimizers, callbacks, viewer and loggers, in that order.

        Mirrors ``Trainer.setup`` except that the local writer's ``max_iter`` is derived from the
        static/dynamic iteration settings instead of ``max_num_iterations``.

        Args:
            test_mode:
                'val': loads train/val datasets into memory
                'test': loads train/test datasets into memory
                'inference': does not load any dataset into memory
        """
        cfg = self.config

        self.pipeline = cfg.pipeline.setup(
            device=self.device,
            test_mode=test_mode,
            world_size=self.world_size,
            local_rank=self.local_rank,
        )
        self.optimizers = self.setup_optimizers()

        # restore state (when a load_dir is configured) before the callbacks are created
        self._load_checkpoint()

        callback_attributes = TrainingCallbackAttributes(
            optimizers=self.optimizers,  # type: ignore
            grad_scaler=self.grad_scaler,  # type: ignore
            pipeline=self.pipeline,  # type: ignore
        )
        self.callbacks = self.pipeline.get_training_callbacks(callback_attributes)

        # viewer: started only when enabled and only on local rank 0
        viewer_log_path = self.base_dir / cfg.viewer.relative_log_filename
        self.viewer_state = None
        banner_messages = None
        if cfg.is_viewer_enabled() and self.local_rank == 0:
            datapath = self.pipeline.datamanager.get_datapath()
            # fall back to the experiment directory when the datamanager reports no data path
            self.viewer_state, banner_messages = viewer_utils.setup_viewer(
                cfg.viewer,
                log_filename=viewer_log_path,
                datapath=self.base_dir if datapath is None else datapath,
            )
        self._check_viewer_warnings()

        # event writers, local writer and profiler
        writer_log_path = self.base_dir / cfg.logging.relative_log_dir
        writer.setup_event_writer(
            cfg.is_wandb_enabled(),
            cfg.is_tensorboard_enabled(),
            log_dir=writer_log_path,
        )

        # NOTE(review): num_dynamic_frames (a frame count, judging by its name) is ADDED to
        # num_dynamic_iterations rather than multiplied with it — confirm this is the intended
        # total for the local writer.
        static_iterations = 0 if cfg.skip_static else cfg.num_static_iterations
        max_iter = cfg.num_dynamic_frames + cfg.num_dynamic_iterations + static_iterations
        writer.setup_local_writer(cfg.logging, max_iter=max_iter, banner_messages=banner_messages)
        writer.put_config(name="config", config_dict=dataclasses.asdict(cfg), step=0)
        profiler.setup_profiler(cfg.logging)

    @profiler.time_function
    def train_static_iteration(self, step: int) -> TRAIN_INTERATION_OUTPUT:
        """Run one optimization step using the pipeline's static training loss.

        Follows the same sequence as ``train_iteration``; in addition, when
        ``config.tv_loss_weight`` is positive, the field encoding's ``grad_total_variation`` is
        invoked between the backward pass and the optimizer step.

        Args:
            step: Current training step.

        Returns:
            The summed loss, the dictionary of individual losses and the metrics dictionary
            (extended with gradient norms when ``config.log_gradients`` is set).
        """
        self.optimizers.zero_grad_all()

        # autocast wants "cpu"/"cuda", i.e. the device string without its index suffix
        device_type: str = self.device.partition(":")[0]
        with torch.autocast(device_type=device_type, enabled=self.mixed_precision):
            _, loss_dict, metrics_dict = self.pipeline.get_static_train_loss_dict(step=step)
            loss = functools.reduce(torch.add, loss_dict.values())
        self.grad_scaler.scale(loss).backward()  # type: ignore

        tv_weight = self.config.tv_loss_weight
        if tv_weight > 0:
            # NOTE(review): this runs after backward() on the scaled loss; whether
            # grad_total_variation accounts for the grad scaler cannot be seen from here.
            self.pipeline.model.field.encoding.grad_total_variation(weight=tv_weight)
        self.optimizers.optimizer_scaler_step_all(self.grad_scaler)

        if self.config.log_gradients:
            total_grad = 0
            for tag, param in self.pipeline.model.named_parameters():
                # "Total" is reserved for the summed entry written after the loop
                assert tag != "Total"
                if param.grad is None:
                    continue
                grad_norm = param.grad.norm()
                metrics_dict[f"Gradients/{tag}"] = grad_norm
                total_grad += grad_norm
            metrics_dict["Gradients/Total"] = total_grad

        self.grad_scaler.update()
        self.optimizers.scheduler_step_all(step)

        return loss, loss_dict, metrics_dict

    @profiler.time_function
    def train_dynamic_iteration(self, step: int) -> TRAIN_INTERATION_OUTPUT:
        """Run one optimization step using the pipeline's dynamic training loss.

        Follows the same sequence as ``train_iteration`` with the loss taken from
        ``get_dynamic_train_loss_dict``.

        Args:
            step: Current training step.

        Returns:
            The summed loss, the dictionary of individual losses and the metrics dictionary
            (extended with gradient norms when ``config.log_gradients`` is set).
        """
        self.optimizers.zero_grad_all()

        # autocast wants "cpu"/"cuda", i.e. the device string without its index suffix
        device_type: str = self.device.partition(":")[0]
        with torch.autocast(device_type=device_type, enabled=self.mixed_precision):
            _, loss_dict, metrics_dict = self.pipeline.get_dynamic_train_loss_dict(step=step)
            loss = functools.reduce(torch.add, loss_dict.values())
        self.grad_scaler.scale(loss).backward()  # type: ignore
        self.optimizers.optimizer_scaler_step_all(self.grad_scaler)

        if self.config.log_gradients:
            total_grad = 0
            for tag, param in self.pipeline.model.named_parameters():
                # "Total" is reserved for the summed entry written after the loop
                assert tag != "Total"
                if param.grad is None:
                    continue
                grad_norm = param.grad.norm()
                metrics_dict[f"Gradients/{tag}"] = grad_norm
                total_grad += grad_norm
            metrics_dict["Gradients/Total"] = total_grad

        self.grad_scaler.update()
        self.optimizers.scheduler_step_all(step)

        return loss, loss_dict, metrics_dict

    def eval_static_one_image(self, step, name: str):
        """Evaluate one image and log its metrics and images under ``name``.

        Args:
            step: Step used for the evaluation call and for every logged entry.
            name: Name of the metrics entry and prefix ("<name>/<image_name>") of the images.
        """
        eval_metrics, eval_images = self.pipeline.get_eval_image_metrics_and_images(step=step)
        writer.put_dict(name=name, scalar_dict=eval_metrics, step=step)
        for image_key, image_value in eval_images.items():
            writer.put_image(name=name + "/" + image_key, image=image_value, step=step)

    @profiler.time_function
    def eval_static_iteration(self, step: int) -> None:
        """Run the static-stage evaluations that are due at this step.

        The batch, single-image and all-images evaluations are checked independently against
        the ``static_steps_per_eval_*`` cadences of the config.

        Args:
            step: Current training step.
        """
        cfg = self.config

        # a batch of eval rays
        if step_check(step, cfg.static_steps_per_eval_batch):
            _, eval_loss_dict, eval_metrics_dict = self.pipeline.get_static_eval_loss_dict(step=step)
            eval_loss = functools.reduce(torch.add, eval_loss_dict.values())
            writer.put_scalar(name="Eval Loss", scalar=eval_loss, step=step)
            writer.put_dict(name="Eval Loss Dict", scalar_dict=eval_loss_dict, step=step)
            writer.put_dict(name="Eval Metrics Dict", scalar_dict=eval_metrics_dict, step=step)

        # one eval image, timed so that the test rays/sec can be reported
        if step_check(step, cfg.static_steps_per_eval_image):
            with TimeWriter(writer, EventName.TEST_RAYS_PER_SEC, write=False) as test_t:
                metrics_dict, images_dict = self.pipeline.get_eval_image_metrics_and_images(step=step)
            test_rays_per_sec = metrics_dict["num_rays"] / test_t.duration
            writer.put_time(
                name=EventName.TEST_RAYS_PER_SEC,
                duration=test_rays_per_sec,
                step=step,
                avg_over_steps=True,
            )
            writer.put_dict(name="Eval Images Metrics", scalar_dict=metrics_dict, step=step)
            for image_name, image in images_dict.items():
                writer.put_image(name="Eval Images/" + image_name, image=image, step=step)

        # all eval images
        if step_check(step, cfg.static_steps_per_eval_all_images):
            avg_metrics_dict = self.pipeline.get_average_eval_image_metrics(step=step)
            writer.put_dict(name="Eval Images Metrics Dict (all images)", scalar_dict=avg_metrics_dict, step=step)

    @profiler.time_function
    def eval_all_iteration(self, step: int) -> None:
        """Run the evaluations governed by the ``all_steps_per_eval_*`` cadences.

        Same structure as ``eval_static_iteration``; the single-image evaluation additionally
        logs the current frame's eval mask as "Eval Images/gt_mask".

        Args:
            step: Current training step.
        """
        cfg = self.config

        # a batch of eval rays
        # NOTE(review): this uses get_static_eval_loss_dict, the same call as
        # eval_static_iteration — confirm that is intended here.
        if step_check(step, cfg.all_steps_per_eval_batch):
            _, eval_loss_dict, eval_metrics_dict = self.pipeline.get_static_eval_loss_dict(step=step)
            eval_loss = functools.reduce(torch.add, eval_loss_dict.values())
            writer.put_scalar(name="Eval Loss", scalar=eval_loss, step=step)
            writer.put_dict(name="Eval Loss Dict", scalar_dict=eval_loss_dict, step=step)
            writer.put_dict(name="Eval Metrics Dict", scalar_dict=eval_metrics_dict, step=step)

        # one eval image, timed so that the test rays/sec can be reported
        if step_check(step, cfg.all_steps_per_eval_image):
            with TimeWriter(writer, EventName.TEST_RAYS_PER_SEC, write=False) as test_t:
                metrics_dict, images_dict = self.pipeline.get_eval_image_metrics_and_images(step=step)
            test_rays_per_sec = metrics_dict["num_rays"] / test_t.duration
            writer.put_time(
                name=EventName.TEST_RAYS_PER_SEC,
                duration=test_rays_per_sec,
                step=step,
                avg_over_steps=True,
            )
            writer.put_dict(name="Eval Images Metrics", scalar_dict=metrics_dict, step=step)
            for image_name, image in images_dict.items():
                writer.put_image(name="Eval Images/" + image_name, image=image, step=step)

            # the mask is squeezed and given a trailing dimension before being logged
            gt_mask = self.pipeline.get_cur_frame_eval_mask().squeeze().unsqueeze(-1)
            writer.put_image(name="Eval Images/gt_mask", image=gt_mask, step=step)

        # all eval images
        if step_check(step, cfg.all_steps_per_eval_all_images):
            avg_metrics_dict = self.pipeline.get_average_eval_image_metrics(step=step)
            writer.put_dict(name="Eval Images Metrics Dict (all images)", scalar_dict=avg_metrics_dict, step=step)

    def eval_after_one_frame(self, step: int) -> None:
        """Log evaluation results once training on the current frame has finished.

        Writes the eval image metrics and every eval image under a per-frame
        name, the current frame's eval mask, and a "mixed" image computed as
        ``mask * rgb + (1 - mask) * last_frame``.

        Args:
            step: step index attached to everything sent to the writer
        """
        CONSOLE.print(f"Evaluating on Frame {self.pipeline.cur_frame}")
        frame_metrics, frame_images = self.pipeline.get_eval_image_metrics_and_images(step=step)

        metrics_name = f"Eval Images Metrics Frame {self.pipeline.cur_frame}"
        writer.put_dict(name=metrics_name, scalar_dict=frame_metrics, step=step)

        image_prefix = f"Eval Images Frame {self.pipeline.cur_frame}" + "/"
        for image_name, image in frame_images.items():
            writer.put_image(name=image_prefix + image_name, image=image, step=step)

        gt_mask = self.pipeline.get_cur_frame_eval_mask().squeeze().unsqueeze(-1)
        writer.put_image(name="Eval Images/gt_mask", image=gt_mask, step=step)

        # Draw a mixed version of the final image: mask * trained + (1 - mask) * last_frame.
        # Mask and last frame are moved to the dtype/device of the "img" entry.
        rendered = frame_images["img"]
        mask = self.pipeline.get_cur_frame_eval_mask().squeeze().unsqueeze(-1).to(rendered)
        last_frame = self.pipeline.get_eval_last_frame().to(rendered)

        # Second chunk of size(1) // 2 along dim 1 - presumably the rendered half of a
        # side-by-side (ground truth | prediction) image; TODO confirm against the pipeline.
        half_extent = rendered.size(1) // 2
        rgb = torch.split(rendered, half_extent, dim=1)[1]
        mixed_image = mask * rgb + (1 - mask) * last_frame

        writer.put_image(name="Eval Images/mixed img", image=mixed_image, step=step)

    # @profiler.time_function
    # def eval_dynamic_iteration(self, step: int) -> None:
    #     # if step_check(step, self.config.)
    #     if step_check(step, self.config.dynamic_steps_per_eval_batch):
    #         _, eval_loss_dict, eval_metrics_dict = self.pipeline.get_dynamic_eval_loss_dict(step=step)
    #         eval_loss = functools.reduce(torch.add, eval_loss_dict.values())
    #         writer.put_scalar(name="Eval Loss", scalar=eval_loss, step=step)
    #         writer.put_dict(name="Eval Loss Dict", scalar_dict=eval_loss_dict, step=step)
    #         writer.put_dict(name="Eval Metrics Dict", scalar_dict=eval_metrics_dict, step=step)

    #     # one eval image
    #     if step_check(step, self.config.dynamic_steps_per_eval_image):
    #         with TimeWriter(writer, EventName.TEST_RAYS_PER_SEC, write=False) as test_t:
    #             metrics_dict, images_dict = self.pipeline.get_eval_image_metrics_and_images(step=step)
    #         writer.put_time(
    #             name=EventName.TEST_RAYS_PER_SEC,
    #             duration=metrics_dict["num_rays"] / test_t.duration,
    #             step=step,
    #             avg_over_steps=True,
    #         )
    #         writer.put_dict(name="Eval Images Metrics", scalar_dict=metrics_dict, step=step)
    #         group = "Eval Images"
    #         for image_name, image in images_dict.items():
    #             writer.put_image(name=group + "/" + image_name, image=image, step=step)

    def save_static_checkpoint(self, step: int):
        """Save the pipeline, optimizer and grad-scaler state of the static stage.

        The checkpoint is written to ``<checkpoint_dir>/static-step-<step:09d>.ckpt``.
        When ``config.save_only_latest_checkpoint`` is set, every other regular
        file in the checkpoint directory is deleted afterwards. Sub-directories
        are skipped, since ``Path.unlink`` cannot remove a directory and would
        raise instead.

        Args:
            step: number of steps in training for given checkpoint
        """
        # mkdir with exist_ok=True is already a no-op for an existing directory
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        ckpt_path: Path = self.checkpoint_dir / f"static-step-{step:09d}.ckpt"

        # A pipeline exposing ``.module`` (presumably a distributed wrapper) is unwrapped first.
        pipeline = self.pipeline.module if hasattr(self.pipeline, "module") else self.pipeline  # type: ignore
        torch.save(
            {
                "step": step,
                "pipeline": pipeline.state_dict(),
                "optimizers": {k: v.state_dict() for (k, v) in self.optimizers.optimizers.items()},
                "scalers": self.grad_scaler.state_dict(),
            },
            ckpt_path,
        )

        # possibly delete old checkpoints
        if self.config.save_only_latest_checkpoint:
            for stale in self.checkpoint_dir.glob("*"):
                # only regular files can be unlinked; leave directories in place
                if stale != ckpt_path and stale.is_file():
                    stale.unlink()

    def save_dynamic_checkpoint(self, step: int):
        """Save the pipeline, optimizer and grad-scaler state of the dynamic stage.

        The checkpoint is written to ``<checkpoint_dir>/dynamic-step-<step:09d>.ckpt``
        and additionally records ``self.pipeline.cur_frame`` under ``"cur_frame"``.
        When ``config.save_only_latest_checkpoint`` is set, every other regular
        file in the checkpoint directory is deleted afterwards. Sub-directories
        are skipped, since ``Path.unlink`` cannot remove a directory and would
        raise instead.

        Args:
            step: number of steps in training for given checkpoint
        """
        # mkdir with exist_ok=True is already a no-op for an existing directory
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        ckpt_path: Path = self.checkpoint_dir / f"dynamic-step-{step:09d}.ckpt"

        # A pipeline exposing ``.module`` (presumably a distributed wrapper) is unwrapped
        # for its state dict; ``cur_frame`` is still read from ``self.pipeline`` itself.
        pipeline = self.pipeline.module if hasattr(self.pipeline, "module") else self.pipeline  # type: ignore
        torch.save(
            {
                "step": step,
                "pipeline": pipeline.state_dict(),
                "optimizers": {k: v.state_dict() for (k, v) in self.optimizers.optimizers.items()},
                "scalers": self.grad_scaler.state_dict(),
                "cur_frame": self.pipeline.cur_frame,
            },
            ckpt_path,
        )

        # possibly delete old checkpoints
        if self.config.save_only_latest_checkpoint:
            for stale in self.checkpoint_dir.glob("*"):
                # only regular files can be unlinked; leave directories in place
                if stale != ckpt_path and stale.is_file():
                    stale.unlink()

    def train_static(self):
        """Train the static part of the model and save a static checkpoint.

        Saves the dataparser transform to ``base_dir``, initialises the viewer
        state and runs ``config.num_static_iterations`` training steps. Each
        step runs the before/after callbacks around ``train_static_iteration``,
        optionally upsamples the field and rebuilds the optimizers, logs
        throughput and training scalars, evaluates when an eval dataset is
        present and periodically saves a checkpoint. The static checkpoint is
        written after the loop with the last step index.
        """
        datamanager = self.pipeline.datamanager
        assert datamanager.train_dataset is not None, "Missing DatsetInputs"

        datamanager.train_dataparser_outputs.save_dataparser_transform(
            self.base_dir / "dataparser_transforms.json"
        )

        self._init_viewer_state()

        CONSOLE.print("begin to train static part")

        def run_callbacks(at_step, location):
            # dispatch every registered training callback for the given location
            for registered in self.callbacks:
                registered.run_callback_at_location(at_step, location=location)

        with TimeWriter(writer, EventName.STATIC_TOTAL_TRAIN_TIME):
            total_steps = self.config.num_static_iterations
            # keeps `step` defined for the final checkpoint even when no iteration runs
            step = 0
            for step in range(total_steps):
                with TimeWriter(writer, EventName.STATIC_ITER_TRAIN_TIME, step=step) as iter_timer:
                    self.pipeline.train()

                    run_callbacks(step, TrainingCallbackLocation.BEFORE_TRAIN_ITERATION)
                    loss, loss_dict, metrics_dict = self.train_static_iteration(step)
                    run_callbacks(step, TrainingCallbackLocation.AFTER_TRAIN_ITERATION)

                    if self.config.upsample_enable and step in self.config.upsample_steps:
                        # the position of `step` in upsample_steps selects the new resolution
                        level = self.config.upsample_steps.index(step)
                        new_resolution = self.config.resolution_list[level]
                        CONSOLE.log(f"upsample resolution to {new_resolution}")
                        self.pipeline.model.field.upsample(new_resolution)
                        CONSOLE.log("reset optimizers to include upsampled parameters")
                        self.optimizers = self.setup_optimizers(reset_step=step)
                        torch.cuda.empty_cache()

                # throughput is only reported from the third step on
                if step > 1 and self.config.log_train_t:
                    rays_per_sec = self.pipeline.datamanager.get_train_rays_per_batch() / iter_timer.duration
                    writer.put_time(
                        name=EventName.TRAIN_RAYS_PER_SEC,
                        duration=rays_per_sec,
                        step=step,
                        avg_over_steps=True,
                    )

                self._update_viewer_state(step)

                # a batch of train rays
                if step_check(step, self.config.logging.steps_per_log, run_at_zero=True):
                    writer.put_scalar(name="Train Loss", scalar=loss, step=step)
                    writer.put_dict(name="Train Loss Dict", scalar_dict=loss_dict, step=step)
                    writer.put_dict(name="Train Metrics Dict", scalar_dict=metrics_dict, step=step)

                # Do not perform evaluation if there are no validation images
                if self.pipeline.datamanager.eval_dataset:
                    self.pipeline.eval()
                    self.eval_static_iteration(step)

                if step_check(step, self.config.steps_per_save):
                    self.save_checkpoint(step)

                writer.write_out_storage()

        # save the static checkpoint once the static stage is done
        self.save_static_checkpoint(step)

    def train_dynamic(self):
        """Train the dynamic part of the model, one frame at a time.

        Rebuilds the optimizers, calls ``eval_static_one_image`` once with the
        label "Static Model", then for each of ``config.num_dynamic_frames``
        frames:

        1. advances the pipeline with ``tick()``;
        2. calls ``pipeline.set_static`` and ``pipeline.hash_reinitialize`` a
           number of times derived from the static / dynamic ray counts and
           the train batch size;
        3. optionally resets the proposal samplers;
        4. runs ``config.num_dynamic_iterations`` training steps with
           callbacks, logging, evaluation and periodic checkpoints;
        5. calls ``eval_after_one_frame``.

        Afterwards a dynamic checkpoint is saved and, unless
        ``config.viewer.quit_on_train_completion`` is set, ``_always_render``
        is invoked.

        Note: ``step`` restarts at 0 for every frame. The value handed to
        ``eval_after_one_frame``, ``save_dynamic_checkpoint`` and
        ``_always_render`` is therefore the last per-frame step index (or 0 if
        no iteration ran), not a global step count.
        """
        self._reset_optimizer()
        CONSOLE.print("begin to train dynamic part")
        # Logged steps of the dynamic stage are shifted past the static iterations,
        # unless the static stage was skipped.
        static_step_offset = 0 if self.config.skip_static else self.config.num_static_iterations
        step = 0
        self.eval_static_one_image(step, "Static Model")
        with TimeWriter(writer, EventName.DYNAMIC_TOTAL_TRAIN_TIME):
            for i in range(self.config.num_dynamic_frames):
                # tick to next frame
                self.pipeline.tick()

                ## test mask
                # mask = self.pipeline.datamanager.eval_dataset.mask.squeeze(dim=0).cpu().numpy()
                # CONSOLE.print(np.count_nonzero(mask))

                CONSOLE.print(f"cur frame: {self.pipeline.cur_frame}")
                CONSOLE.print(f"num dynamic rays: {self.pipeline.num_dynamic_rays}")


                # disable static voxels updating
                CONSOLE.print(f"num static rays: {self.pipeline.num_static_rays}")
                CONSOLE.print(f"num static set iterations: {self.pipeline.num_static_rays//self.pipeline.datamanager.get_train_rays_per_batch()}")

                # NOTE(review): the loop below additionally divides by config.step_scale, so it can run
                # fewer times than the count printed above - confirm which number is intended.
                with Timer(des="setting static"):
                    for i_reinit in range(self.pipeline.num_static_rays//self.pipeline.datamanager.get_train_rays_per_batch()//self.config.step_scale):
                        # CONSOLE.print(i_reinit)
                        self.pipeline.set_static(i_reinit)
                # reports how many grid_mask entries equal 0 out of the first dimension of grid_mask
                CONSOLE.print(f"setting stat: {(self.pipeline.model.field.encoding.grid_mask==0).sum()}/{self.pipeline.model.field.encoding.grid_mask.shape[0]}")

                # reset the hash values
                CONSOLE.print(f"resetting grid values for {self.pipeline.num_dynamic_rays//self.pipeline.datamanager.get_train_rays_per_batch()} iterations")
                for i_reinit in range(self.pipeline.num_dynamic_rays//self.pipeline.datamanager.get_train_rays_per_batch()):
                    # CONSOLE.print(i_reinit)
                    self.pipeline.hash_reinitialize(step=i_reinit, std=self.config.hash_reinit_std)

                if self.config.reset_proposal_samplers:
                    CONSOLE.print(f"reset parameters for proposal network")
                    self.pipeline.model.reset_proposal_samplers()

                num_iterations = self.config.num_dynamic_iterations
                CONSOLE.print(f"num of iteration: {num_iterations}")
                for step in range(num_iterations):
                    # event_step = step + num_iterations * i + self.config.num_static_iterations
                    with TimeWriter(writer, EventName.DYNAMIC_ITER_TRAIN_TIME, step=step) as train_t:
                        self.pipeline.train()

                        # training callbacks before the training iteration
                        for callback in self.callbacks:
                            callback.run_callback_at_location(
                                step, location=TrainingCallbackLocation.BEFORE_TRAIN_ITERATION
                            )

                        loss, loss_dict, metrics_dict = self.train_dynamic_iteration(step)
                        # print(list(self.pipeline.model.field.mlp_head.parameters()))

                        # training callbacks after the training iteration
                        for callback in self.callbacks:
                            callback.run_callback_at_location(
                                step, location=TrainingCallbackLocation.AFTER_TRAIN_ITERATION
                            )

                    # throughput is only reported from the third per-frame step on
                    if step > 1 and self.config.log_train_t:
                        writer.put_time(
                            name=EventName.TRAIN_RAYS_PER_SEC,
                            duration=self.pipeline.datamanager.get_train_rays_per_batch() / train_t.duration,
                            step=step + static_step_offset,
                            avg_over_steps=True,
                        )

                    # the viewer, evaluation and checkpointing below receive the un-offset per-frame step
                    self._update_viewer_state(step)

                    # a batch of train rays
                    if step_check(step, self.config.logging.steps_per_log, run_at_zero=True):
                        writer.put_scalar(name="Train Loss", scalar=loss, step=step + static_step_offset)
                        writer.put_dict(name="Train Loss Dict", scalar_dict=loss_dict, step=step + static_step_offset)
                        writer.put_dict(
                            name="Train Metrics Dict", scalar_dict=metrics_dict, step=step + static_step_offset
                        )

                    # Do not perform evaluation if there are no validation images
                    if self.pipeline.datamanager.eval_dataset:
                        # self.eval_dynamic_iteration(step)
                        self.eval_all_iteration(step)

                    if step_check(step, self.config.steps_per_save):
                        self.save_checkpoint(step)

                    writer.write_out_storage()

                # per-frame evaluation once this frame's iterations are done
                self.eval_after_one_frame(step)

        # save the final dynamic checkpoint, then write out any remaining events (e.g., total train time)

        self.save_dynamic_checkpoint(step)
        writer.write_out_storage()

        CONSOLE.rule()
        CONSOLE.print("[bold green]:tada: :tada: :tada: Training Finished :tada: :tada: :tada:", justify="center")
        if not self.config.viewer.quit_on_train_completion:
            CONSOLE.print("Use ctrl+c to quit", justify="center")
            self._always_render(step)

    def train(self, skip_static=False) -> None:
        """Train the video model: the static stage first, then the dynamic stage.

        When ``config.skip_static`` is set, the static stage is not trained;
        a pre-trained static model is loaded from ``config.static_model_path``
        instead, which must then be provided. Otherwise ``train_static`` runs.
        ``train_dynamic`` runs afterwards in both cases.

        Args:
            skip_static: currently not read by this method; the decision is
                taken from ``self.config.skip_static``.
        """
        if self.config.skip_static:
            # reuse an already trained static model instead of training one
            assert (
                self.config.static_model_path is not None
            ), "Skip training of static part requires providing well-trained static model"
            self.load_static_checkpoint()
            CONSOLE.print("Skip static paer training")
        else:
            self.train_static()

        self.train_dynamic()

    def load_static_checkpoint(self):
        """Load the pipeline state of a pre-trained static model, if one is configured.

        Does nothing when ``config.static_model_path`` is ``None``. Otherwise
        the checkpoint is read onto the CPU and its ``"pipeline"`` entry is
        handed to ``self.pipeline.load_pipeline`` with ``0`` as second argument.

        NOTE: ``torch.load`` unpickles the file; only load trusted checkpoints.
        """
        checkpoint_path = self.config.static_model_path
        if checkpoint_path is None:
            return
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        self.pipeline.load_pipeline(checkpoint["pipeline"], 0)

    def load_from_static(self):
        """Restore training state by delegating to ``self._load_checkpoint()``."""
        # TODO: add load from static model and continue training, make sure the datasets are ready
        self._load_checkpoint()

    def _reset_optimizer(self):
        """Rebuild ``self.optimizers`` via ``setup_optimizers()``.

        When ``config.disable_mlp`` is set, ``self.pipeline.model.disable_MLP()``
        is called first, so the new optimizers are created after that call.
        """
        CONSOLE.print("re-initialzing optimizers")
        if self.config.disable_mlp:
            self.pipeline.model.disable_MLP()
        self.optimizers = self.setup_optimizers()
